package com.fintech.pangu.elasticsearch.service.impl;

import com.fintech.pangu.elasticsearch.constants.RuleConstant;
import com.fintech.pangu.elasticsearch.service.AggregationService;
import com.fintech.pangu.elasticsearch.validation.NestedQueryGroup;
import com.fintech.pangu.elasticsearch.validation.ObjectQueryGroup;
import com.fintech.pangu.elasticsearch.validation.ValidatorsUtil;
import com.fintech.pangu.elasticsearch.vo.AggregationSearchReqVO;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.bucket.nested.NestedAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregatorBuilders;
import org.elasticsearch.search.aggregations.pipeline.bucketselector.BucketSelectorPipelineAggregationBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.Assert;

import java.util.HashMap;
import java.util.Map;

/**
 * Implementation of {@link AggregationService} providing common aggregation queries:
 * single-field terms aggregation ("group by"), optional "having" bucket filtering on the
 * bucket document count, aggregation over a nested path, and multi-field aggregation.
 * All searches are executed through an Elasticsearch {@link TransportClient}.
 *
 * <p>Configure the instance through a constructor or the fluent
 * {@link #transportClient(TransportClient)} / {@link #aggMaxSize(int)} setters. A client
 * must be set before any query method is called.
 *
 * @author chendongdong
 * @since 1.0.0
 **/
public class AggregationServiceImpl implements AggregationService {

    private static final Logger logger = LoggerFactory.getLogger(AggregationServiceImpl.class);

    /**
     * Default maximum number of buckets returned by each terms aggregation.
     */
    private static final int DEFAULT_AGGREGATION_MAX_SIZE = 10000;

    /**
     * Name of the bucket_selector pipeline aggregation that implements "having".
     */
    private static final String HAVING_AGGREGATION_NAME = "having";

    /**
     * Script variable bound, via buckets_path, to each bucket's document count.
     */
    private static final String BUCKET_COUNT_VAR = "orderCount";

    /**
     * Script parameter carrying the minimum document count a bucket must reach.
     */
    private static final String MIN_OCCURRENCES_PARAM = "minValueOccurrences";

    /**
     * Source of the "having" script. The threshold is passed as a script parameter
     * instead of being concatenated into the source. The script text is therefore constant,
     * and Elasticsearch compiles and caches it once, not once per distinct threshold.
     */
    private static final String HAVING_SCRIPT_SOURCE =
            "params." + BUCKET_COUNT_VAR + " >= params." + MIN_OCCURRENCES_PARAM;

    private TransportClient transportClient;

    /**
     * Maximum number of buckets returned by each terms aggregation.
     */
    private int aggMaxSize = DEFAULT_AGGREGATION_MAX_SIZE;


    /**
     * Creates an instance without a client. Call {@link #transportClient(TransportClient)}
     * before running any query.
     */
    public AggregationServiceImpl() {
    }

    /**
     * Creates an instance bound to the given client.
     *
     * @param transportClient client used to execute searches
     */
    public AggregationServiceImpl(TransportClient transportClient) {
        this.transportClient = transportClient;
    }

    /**
     * Sets the client used to execute searches.
     *
     * @param transportClient client used to execute searches
     * @return this instance, for chaining
     */
    public AggregationServiceImpl transportClient(TransportClient transportClient) {
        this.transportClient = transportClient;
        return this;
    }

    /**
     * Sets the maximum number of buckets returned by each terms aggregation.
     *
     * @param aggMaxSize bucket limit; must be greater than 0
     * @return this instance, for chaining
     * @throws IllegalArgumentException if {@code aggMaxSize} is not positive
     */
    public AggregationServiceImpl aggMaxSize(int aggMaxSize) {
        // A non-positive terms size is never valid; reject it here rather than on every later query.
        Assert.isTrue(aggMaxSize > 0, "聚合统计最大数目[aggMaxSize]必须大于0");
        this.aggMaxSize = aggMaxSize;
        return this;
    }

    /**
     * Runs a terms aggregation on a single field.
     *
     * @param aggregationSearchReqVO aggregation request (index and field are used)
     * @return the raw search response containing the terms aggregation
     **/
    @Override
    public SearchResponse getFieldGroupBy(AggregationSearchReqVO aggregationSearchReqVO) {
        String index = aggregationSearchReqVO.getIndex();
        AggregationBuilder termsAggBuilder = this.buildAggregationBuilder(aggregationSearchReqVO, false, false);
        return search(index, termsAggBuilder);
    }

    /**
     * Runs a terms aggregation on a single field and keeps only buckets whose document count
     * is at least {@code minValueOccurrences} ("having").
     *
     * @param aggregationSearchReqVO aggregation request, validated against {@link ObjectQueryGroup}
     * @return the raw search response containing the filtered terms aggregation
     **/
    @Override
    public SearchResponse getFieldGroupByHaving(AggregationSearchReqVO aggregationSearchReqVO) {
        ValidatorsUtil.validateWithException(aggregationSearchReqVO, ObjectQueryGroup.class);
        String index = aggregationSearchReqVO.getIndex();
        AggregationBuilder termsAggBuilder = this.buildAggregationBuilder(aggregationSearchReqVO, false, true);
        return search(index, termsAggBuilder);
    }

    /**
     * Same as {@link #getFieldGroupByHaving(AggregationSearchReqVO)}, but restricts the
     * aggregated documents with the request's query builder.
     *
     * @param aggregationSearchReqVO aggregation request, validated against {@link ObjectQueryGroup}
     * @return the raw search response containing the filtered terms aggregation
     **/
    @Override
    public SearchResponse getFieldGroupByHavingWithFilter(AggregationSearchReqVO aggregationSearchReqVO) {
        ValidatorsUtil.validateWithException(aggregationSearchReqVO, ObjectQueryGroup.class);
        String index = aggregationSearchReqVO.getIndex();
        QueryBuilder queryBuilder = aggregationSearchReqVO.getQueryBuilder();
        AggregationBuilder termsAggBuilder = this.buildAggregationBuilder(aggregationSearchReqVO, false, true);
        return searchWithFilter(index, queryBuilder, termsAggBuilder);
    }


    /**
     * Runs a terms aggregation on a single field inside a nested aggregation on the request's path.
     *
     * @param aggregationSearchReqVO aggregation request (index, path and field are used)
     * @return the raw search response containing the nested aggregation
     **/
    @Override
    public SearchResponse getNestedFieldGroupBy(AggregationSearchReqVO aggregationSearchReqVO) {
        String index = aggregationSearchReqVO.getIndex();
        AggregationBuilder nestedAggBuilder = this.buildAggregationBuilder(aggregationSearchReqVO, true, false);
        return search(index, nestedAggBuilder);
    }

    /**
     * Nested single-field terms aggregation with the "having" minimum-count bucket filter.
     *
     * @param aggregationSearchReqVO aggregation request, validated against {@link NestedQueryGroup}
     * @return the raw search response containing the nested aggregation
     **/
    @Override
    public SearchResponse getNestedFieldGroupByHaving(AggregationSearchReqVO aggregationSearchReqVO) {
        ValidatorsUtil.validateWithException(aggregationSearchReqVO, NestedQueryGroup.class);
        String index = aggregationSearchReqVO.getIndex();
        AggregationBuilder nestedAggBuilder = this.buildAggregationBuilder(aggregationSearchReqVO, true, true);
        return search(index, nestedAggBuilder);
    }


    /**
     * Same as {@link #getNestedFieldGroupByHaving(AggregationSearchReqVO)}, but restricts the
     * aggregated documents with the request's query builder.
     *
     * @param aggregationSearchReqVO aggregation request, validated against {@link NestedQueryGroup}
     * @return the raw search response containing the nested aggregation
     **/
    @Override
    public SearchResponse getNestedFieldGroupByHavingWithFilter(AggregationSearchReqVO aggregationSearchReqVO) {
        ValidatorsUtil.validateWithException(aggregationSearchReqVO, NestedQueryGroup.class);
        String index = aggregationSearchReqVO.getIndex();
        QueryBuilder queryBuilder = aggregationSearchReqVO.getQueryBuilder();
        AggregationBuilder nestedAggBuilder = this.buildAggregationBuilder(aggregationSearchReqVO, true, true);
        return searchWithFilter(index, queryBuilder, nestedAggBuilder);
    }


    /**
     * Builds the aggregation for a single-field request.
     *
     * <p>The terms aggregation is named {@code <field>_<FIELD_COUNT_ALIAS_SUFFIX>}, with dots
     * in the field replaced by underscores. When {@code nestedFlag} is set, the terms
     * aggregation is wrapped in a nested aggregation whose name and path are both the
     * request's path.
     *
     * @param aggregationSearchReqVO aggregation request
     * @param nestedFlag             whether the field lives in a nested structure (true: yes; false: no)
     * @param bucketFlag             whether to add the "having" bucket filter (true: filter; false: no filter)
     * @return the terms aggregation, or the nested aggregation wrapping it
     * @throws IllegalArgumentException if the request or its field is missing; if {@code bucketFlag}
     *                                  is set without {@code minValueOccurrences}; or if
     *                                  {@code nestedFlag} is set without a path
     **/
    public AggregationBuilder buildAggregationBuilder(AggregationSearchReqVO aggregationSearchReqVO, boolean nestedFlag, boolean bucketFlag) {
        Assert.notNull(aggregationSearchReqVO, "聚合查询请求[aggregationSearchReqVO]为空");
        // Field in the form "<type code>.<field code>".
        String field = aggregationSearchReqVO.getField();
        Assert.hasText(field, "字段[field]为空");
        // Dots are replaced by underscores in the alias; presumably because '.' is the
        // metric separator in buckets_path syntax (TODO confirm intent).
        String fieldCountAlias = String.join(RuleConstant.UNDER_LINE,
                field.replace(RuleConstant.ENGLISH_FULL_STOP, RuleConstant.UNDER_LINE),
                RuleConstant.FIELD_COUNT_ALIAS_SUFFIX);
        TermsAggregationBuilder termsAggBuilder = AggregationBuilders.terms(fieldCountAlias).field(field).size(aggMaxSize);
        if (bucketFlag) {
            termsAggBuilder.subAggregation(buildHavingSelector(aggregationSearchReqVO.getMinValueOccurrences()));
        }
        if (!nestedFlag) {
            return termsAggBuilder;
        }
        // Nested path (type code); also used as the nested aggregation's name.
        String path = aggregationSearchReqVO.getPath();
        Assert.hasText(path, "路径[path]为空");
        NestedAggregationBuilder nestedAggBuilder = AggregationBuilders.nested(path, path);
        nestedAggBuilder.subAggregation(termsAggBuilder);
        return nestedAggBuilder;
    }

    /**
     * Runs a terms aggregation on {@code firstField} with one terms sub-aggregation per entry of
     * {@code otherFields}. The sub-aggregations are siblings under the first field's buckets, not
     * a nested chain. Each aggregation is named {@code <field>.<FIELD_COUNT_ALIAS_SUFFIX>}.
     *
     * <p>NOTE(review): unlike {@link #buildAggregationBuilder}, these names keep the field's dots
     * and use '.' as the joiner. They are left unchanged because callers look results up by name.
     *
     * @param index       index to search
     * @param firstField  first field ("type code.field code")
     * @param otherFields further fields ("type code.field code"); at least one is required
     * @return the raw search response containing the aggregation
     * @throws IllegalArgumentException if the index or any field is blank, or {@code otherFields} is empty
     **/
    @Override
    public SearchResponse getMultiFieldGroupBy(String index, String firstField, String... otherFields) {
        Assert.hasText(index, "索引[index]为空");
        Assert.hasText(firstField, "字段[firstField]为空");
        Assert.notEmpty(otherFields, "字段[otherFields]为空");
        TermsAggregationBuilder termsAggBuilder = buildDottedTermsAggregation(firstField);
        for (String fieldName : otherFields) {
            Assert.hasText(fieldName, "字段[otherFields]包含空值");
            termsAggBuilder.subAggregation(buildDottedTermsAggregation(fieldName));
        }
        return search(index, termsAggBuilder);
    }

    /**
     * Builds the bucket_selector that keeps buckets with {@code _count >= minValueOccurrences}.
     * The threshold is passed as a script parameter so the script source stays constant.
     */
    private BucketSelectorPipelineAggregationBuilder buildHavingSelector(Integer minValueOccurrences) {
        Assert.notNull(minValueOccurrences, "最小出现频率[minValueOccurrences]为空");
        // Expose each bucket's document count to the script as params.orderCount.
        Map<String, String> bucketsPathsMap = new HashMap<>(2);
        bucketsPathsMap.put(BUCKET_COUNT_VAR, "_count");
        Map<String, Object> scriptParams = new HashMap<>(2);
        scriptParams.put(MIN_OCCURRENCES_PARAM, minValueOccurrences);
        // Same script type and language as new Script(String): inline, default language.
        Script script = new Script(ScriptType.INLINE, Script.DEFAULT_SCRIPT_LANG, HAVING_SCRIPT_SOURCE, scriptParams);
        return PipelineAggregatorBuilders.bucketSelector(HAVING_AGGREGATION_NAME, bucketsPathsMap, script);
    }

    /**
     * Builds a terms aggregation named {@code <field>.<FIELD_COUNT_ALIAS_SUFFIX>}, limited to {@link #aggMaxSize} buckets.
     */
    private TermsAggregationBuilder buildDottedTermsAggregation(String field) {
        String alias = String.join(".", field, RuleConstant.FIELD_COUNT_ALIAS_SUFFIX);
        return AggregationBuilders.terms(alias).field(field).size(aggMaxSize);
    }

    /**
     * Executes an aggregation-only search on the given index.
     */
    private SearchResponse search(String index, AggregationBuilder aggregationBuilder) {
        return requireClient().prepareSearch(index)
                .addAggregation(aggregationBuilder)
                .get();
    }

    /**
     * Executes an aggregation search restricted by {@code queryBuilder}, logging the request first.
     */
    private SearchResponse searchWithFilter(String index, QueryBuilder queryBuilder, AggregationBuilder aggregationBuilder) {
        SearchRequestBuilder searchRequestBuilder = requireClient().prepareSearch(index)
                .setQuery(queryBuilder)
                .addAggregation(aggregationBuilder);
        logger.info("聚合查询构造器:\n{}", searchRequestBuilder);
        return searchRequestBuilder.get();
    }

    /**
     * Returns the configured client, failing fast with a clear message if none was set
     * (for example, after using the no-arg constructor).
     */
    private TransportClient requireClient() {
        Assert.state(transportClient != null, "transportClient未设置");
        return transportClient;
    }
}
